package util;

import org.nd4j.linalg.api.ndarray.INDArray;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;

/**
 * Swing helper that renders rows of ND4J arrays as enlarged grayscale images in a single,
 * lazily created and reused window.
 *
 * <p>NOTE(review): all Swing work here runs on whichever thread calls {@link #visualize}; Swing's
 * threading rules require component access on the Event Dispatch Thread. Confirm how callers
 * invoke this (e.g. from a training loop) before relying on it from background threads.
 */
public class ShowUtilsNormal {
    /** Maximum number of rows rendered from each sample array. */
    private static final int MAX_ROWS_PER_SAMPLE = 8;
    /** Side length, in pixels, of the square image {@link #getImage} builds. */
    private static final int IMAGE_SIDE = 28;
    /** Enlargement factor applied by {@link #getImage}. */
    private static final int GET_IMAGE_SCALE = 5;
    /** Enlargement factor applied by {@link #imageFromINDArray}. */
    private static final int FROM_ARRAY_SCALE = 9;

    private static JFrame frame;
    private static JPanel panel;

    /**
     * Shows up to eight rows of each sample as images in a shared window, replacing whatever the
     * window displayed before. The window is created on the first call.
     *
     * @param samples arrays whose rows are rendered (must not be {@code null}); {@code null} or
     *     empty entries are skipped. Each row is read as 784 values in row-major 28x28 order.
     *     Values are presumably in [0, 1] (TODO confirm); anything mapping outside 0..255 is
     *     clamped to black/white.
     * @param name window title; applied on every call
     */
    public static void visualize(INDArray[] samples, String name) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle(name);
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

            // With a non-zero row count, GridLayout derives the column count from the number of
            // components, so the images wrap into two rows.
            panel = new JPanel(new GridLayout(2, 1, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }
        // Keep the title in sync with the latest caller instead of only the first one.
        frame.setTitle(name);

        panel.removeAll();
        for (INDArray sample : samples) {
            if (sample == null) {
                continue;
            }
            long rows = Math.min(sample.size(0), MAX_ROWS_PER_SAMPLE);
            for (int i = 0; i < rows; i++) {
                panel.add(getImage(sample.getRow(i)));
            }
        }
        frame.revalidate();
        frame.pack();
        frame.repaint();
    }

    /**
     * Renders a 784-element vector as a 28x28 grayscale image enlarged five times.
     *
     * @param tensor values read with linear indices 0..783 in row-major order; each value v is
     *     mapped to {@code (int) (v * 255)} and clamped to 0..255
     * @return a label displaying the enlarged image
     */
    private static JLabel getImage(INDArray tensor) {
        BufferedImage bi = new BufferedImage(IMAGE_SIDE, IMAGE_SIDE, BufferedImage.TYPE_BYTE_GRAY);
        int pixelCount = IMAGE_SIDE * IMAGE_SIDE;
        for (int i = 0; i < pixelCount; i++) {
            // The byte-backed raster truncates samples to 8 bits, so unclamped values would wrap
            // around (256 -> black, -1 -> white) instead of saturating.
            int gray = clampToByte((int) (tensor.getDouble(i) * 255));
            bi.getRaster().setSample(i % IMAGE_SIDE, i / IMAGE_SIDE, 0, gray);
        }
        return scaledLabel(bi, IMAGE_SIDE * GET_IMAGE_SCALE, IMAGE_SIDE * GET_IMAGE_SCALE);
    }

    /**
     * Renders element {@code (0, 0, y, x)} of a rank-4 array as a grayscale image.
     *
     * <p>Height and width are read from {@code shape[2]} and {@code shape[3]}; the layout is
     * presumably [batch, channel, height, width] — TODO confirm. Each value v is mapped to
     * {@code (int) ((v + 1) * 127.5)} (so -1 -> 0 and 1 -> 255) and clamped to 0..255.
     *
     * <p>NOTE(review): this private method is not called anywhere in this class.
     *
     * @param array the array to render
     * @return a label displaying the image enlarged nine times in each dimension
     */
    private static JLabel imageFromINDArray(INDArray array) {
        long[] shape = array.shape();
        int height = (int) shape[2];
        int width = (int) shape[3];
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int gray = clampToByte((int) ((array.getDouble(0, 0, y, x) + 1) * 127.5));
                image.getRaster().setSample(x, y, 0, gray);
            }
        }
        // Scale from the actual dimensions so inputs that are not 28x28 keep their aspect ratio.
        return scaledLabel(image, width * FROM_ARRAY_SCALE, height * FROM_ARRAY_SCALE);
    }

    /** Wraps {@code image} in a label after scaling it to {@code width} x {@code height}. */
    private static JLabel scaledLabel(BufferedImage image, int width, int height) {
        Image scaled = image.getScaledInstance(width, height, Image.SCALE_DEFAULT);
        return new JLabel(new ImageIcon(scaled));
    }

    /** Clamps {@code value} into the 0..255 range of an 8-bit gray sample. */
    private static int clampToByte(int value) {
        return Math.max(0, Math.min(255, value));
    }
}
